{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "import tensorflow as tf\n",
    "sys.path.append(\"../\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from batcher import batcher\n",
    "from pgn import PGN\n",
    "from tqdm import tqdm\n",
    "\n",
    "from utils.test_helper import beam_decode\n",
    "from utils.config import CKPT_DIR\n",
    "from utils.params import get_default_params\n",
    "from utils.saveLoader import Vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.test import test, test_and_save\n",
    "from utils.test_helper import greedy_decode"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "params = get_default_params()\n",
    "params[\"mode\"] = \"test\"\n",
    "params[\"batch_size\"] = 4\n",
    "params[\"num_to_test\"] = 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 测试贪心预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Creating the checkpoint manager\n",
      "Initializing from scratch.\n"
     ]
    }
   ],
   "source": [
    "model = PGN(params)\n",
    "print(\"Creating the checkpoint manager\")\n",
    "checkpoint = tf.train.Checkpoint(Seq2Seq=model)\n",
    "checkpoint_manager = tf.train.CheckpointManager(checkpoint, CKPT_DIR, max_to_keep=5)\n",
    "checkpoint.restore(checkpoint_manager.latest_checkpoint)\n",
    "if checkpoint_manager.latest_checkpoint:\n",
    "    print(\"Restored from {}\".format(checkpoint_manager.latest_checkpoint))\n",
    "else:\n",
    "    print(\"Initializing from scratch.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x1dba8de1508>"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "checkpoint.restore(os.path.join(CKPT_DIR,\"ckpt-2\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mode=test\n"
     ]
    }
   ],
   "source": [
    "vocab = Vocab(params[\"vocab_path\"], params[\"vocab_size\"])\n",
    "ds = batcher(vocab, params)\n",
    "batch = next(iter(ds))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ， ，',\n",
       " '<PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD>',\n",
       " '<PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD>',\n",
       " '<PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD>']"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res = greedy_decode(model, batch, vocab, params)\n",
    "res"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "以下是探究过程"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 初始化mask\n",
    "start_index = vocab.word_to_id(vocab.START_DECODING)\n",
    "stop_index = vocab.word_to_id(vocab.STOP_DECODING)\n",
    "unk_index = vocab.word_to_id(vocab.UNKNOWN_TOKEN)\n",
    "\n",
    "batch_size = params[\"batch_size\"]  # 一个一个预测"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "测试集输入x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# enc_input shape (batch_size, enc_len)\n",
    "enc_input = batch[0][\"enc_input\"]\n",
    "\n",
    "# enc_output shape (batch_size, enc_len, enc_unit)\n",
    "# enc_hidden shape (batch_size, enc_unit)\n",
    "enc_output, enc_hidden = model.call_encoder(enc_input)  # update\n",
    "\n",
    "# dec_input shape (batch_size, 1)\n",
    "# dec_hidden shape (batch_size, enc_unit)\n",
    "dec_input = tf.expand_dims([start_index] * batch_size, 1)  # update\n",
    "dec_hidden = enc_hidden  # update\n",
    "\n",
    "# enc_extended_inp shape (batch_size, enc_len)\n",
    "# batch_oov_len shape (, )\n",
    "# enc_pad_mask shape (batch_size, enc_len)\n",
    "enc_extended_inp = batch[0][\"extended_enc_input\"]  # constant\n",
    "batch_oov_len = batch[0][\"max_oov_len\"]  # constant\n",
    "enc_pad_mask = batch[0][\"sample_encoder_pad_mask\"]  # constant\n",
    "prev_coverage = None  # update"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pred_id= 38 1\n",
      "pred_id= 0 2\n",
      "pred_id= 38 3\n",
      "pred_id= 0 4\n",
      "pred_id= 38 5\n",
      "pred_id= 0 6\n",
      "pred_id= 38 7\n",
      "pred_id= 0 8\n",
      "pred_id= 38 9\n",
      "pred_id= 0 10\n",
      "pred_id= 38 11\n",
      "pred_id= 0 12\n",
      "pred_id= 38 13\n",
      "pred_id= 0 14\n",
      "pred_id= 38 15\n",
      "pred_id= 0 16\n",
      "pred_id= 38 17\n",
      "pred_id= 0 18\n",
      "pred_id= 38 19\n",
      "pred_id= 0 20\n",
      "pred_id= 38 21\n",
      "pred_id= 0 22\n",
      "pred_id= 38 23\n",
      "pred_id= 0 24\n",
      "pred_id= 38 25\n",
      "pred_id= 0 26\n",
      "pred_id= 38 27\n",
      "pred_id= 0 28\n",
      "pred_id= 38 29\n",
      "pred_id= 0 30\n",
      "pred_id= 38 31\n",
      "pred_id= 0 32\n",
      "pred_id= 38 33\n",
      "pred_id= 0 34\n",
      "pred_id= 38 35\n",
      "pred_id= 0 36\n",
      "pred_id= 38 37\n",
      "pred_id= 0 38\n",
      "pred_id= 38 39\n",
      "pred_id= 0 40\n",
      "pred_id= 38 41\n",
      "pred_id= 0 42\n",
      "pred_id= 38 43\n",
      "pred_id= 0 44\n",
      "pred_id= 38 45\n",
      "pred_id= 0 46\n",
      "pred_id= 38 47\n",
      "pred_id= 0 48\n",
      "pred_id= 38 49\n",
      "pred_id= 0 50\n",
      "pred_id= 38 51\n",
      "pred_id= 0 52\n"
     ]
    }
   ],
   "source": [
    "# 遍历步数\n",
    "predicts=[''] * batch_size\n",
    "steps = 0  # initial ste\n",
    "while steps < params['max_dec_len']:\n",
    "    final_preds, dec_hidden, context_vector, \\\n",
    "    attention_weights, p_gens, coverage_ret = model.call_decoder_onestep(dec_input,  # update\n",
    "                                                                        dec_hidden,  # update\n",
    "                                                                        enc_output,  # constant\n",
    "                                                                        enc_extended_inp,  # constant\n",
    "                                                                        batch_oov_len,  # constant\n",
    "                                                                        enc_pad_mask,  # constant\n",
    "                                                                        use_coverage=True,  # constant\n",
    "                                                                        prev_coverage=prev_coverage  # update\n",
    "                                                                        )\n",
    "    predicted_ids = tf.argmax(final_preds, axis=2).numpy()\n",
    "    dec_input = tf.cast(predicted_ids, dtype=dec_input.dtype)  # update\n",
    "    prev_coverage = coverage_ret  # update\n",
    "    \n",
    "    for index,pred_id in enumerate(predicted_ids):\n",
    "        # exmp: pred_id [458]\n",
    "        pred_id = int(pred_id)\n",
    "        if pred_id < vocab.count:\n",
    "            predicts[index]+= vocab.id2word[pred_id] + ' '\n",
    "        else:\n",
    "            # if pred_id is oovs index\n",
    "            predicts[index]+= batch[0][\"article_oovs\"][pred_id-vocab.count]\n",
    "    steps += 1\n",
    "    # print(steps)\n",
    "    \n",
    "results=[]\n",
    "for predict in predicts:\n",
    "    # 去掉句子前后空格\n",
    "    predict=predict.strip()\n",
    "    # 句子小于max len就结束了 截断\n",
    "    if '<STOP>' in predict:\n",
    "        # 截断stop\n",
    "        predict=predict[:predict.index('<STOP>')]\n",
    "    # 保存结果\n",
    "    results.append(predict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ，',\n",
       " '机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ，',\n",
       " '机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ， 机油 ，']"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 测试beam_decode"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "params = get_default_params()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "gen = test(params)\n",
    "hy = next(gen)\n",
    "hy.abstract"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "自己测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|                                                                                            | 0/2 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Building the model ...\n",
      "Creating the vocab ...\n",
      "Creating the batcher ...\n",
      "Creating the checkpoint manager\n",
      "Restored from E:\\GitHub\\QA-abstract-and-reasoning\\data\\checkpoints\\training_checkpoints\\test_model-1\n",
      "Model restored\n",
      "mode=test\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:16<00:00,  8.32s/it]\n"
     ]
    }
   ],
   "source": [
    "test_and_save(params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:tf2.0]",
   "language": "python",
   "name": "conda-env-tf2.0-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
